import torch
import torch.nn as nn

# CNN model definition.
class CNN(nn.Module):
    """Small convolutional classifier with 10 outputs.

    Layer layout (the indices become the ``net.<i>`` keys of the state dict,
    so the order must not change):

        0  Conv2d(1 -> 16, 3x3, stride 1, pad 1)   spatial size preserved
        1  MaxPool2d(2)                            spatial size halved
        2  Conv2d(16 -> 32, 3x3, stride 1, pad 1)  spatial size preserved
        3  MaxPool2d(2)                            spatial size halved
        4  Flatten
        5  Linear(32*7*7 -> 16)
        6  ReLU
        7  Linear(16 -> 10)                        raw scores, no softmax

    Layer 5 hard-codes a 32x7x7 feature map, so inputs must be single-channel
    images whose size pools down to 7x7 (e.g. 28x28).
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE(review): neither Conv2d is followed by an activation. Kept as-is so
        # fresh instances match the architecture the saved checkpoint presumably
        # used; confirm whether this was intentional.
        feature_layers = [
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2),
        ]
        head_layers = [
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 16),
            nn.ReLU(),
            nn.Linear(16, 10),
        ]
        self.net = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the network and return a (N, 10) score tensor."""
        return self.net(x)

device = torch.device('cpu')
# Load the trained checkpoint onto the CPU.
# weights_only=False is needed on PyTorch >= 2.6 (where the default became True)
# to unpickle a whole nn.Module; the keyword itself requires PyTorch >= 1.13.
# SECURITY: this executes arbitrary pickle code -- only load a mnist.pt you trust.
# Unpickling a full model also needs the CNN class above to be resolvable under
# the module path it was saved from.
checkpoint = torch.load("./mnist.pt", map_location=device, weights_only=False)
if isinstance(checkpoint, nn.Module):
    model = checkpoint
else:
    # Presumably a state_dict saved via torch.save(model.state_dict(), ...):
    # rebuild the architecture and copy the weights in.
    model = CNN()
    model.load_state_dict(checkpoint)
    model.to(device)
# Switch to evaluation mode before tracing so train-only behavior
# (e.g. dropout, batch-norm statistics updates) is disabled.
model.eval()

# Names given to the graph's input and output nodes in the ONNX file.
input_names = ['input']
output_names = ['output']
# Dummy input used only to trace the graph: batch of 1, 1 channel, 28x28,
# which pools down to the 32*7*7 feature map CNN's first Linear expects.
# Gradients are not needed for export.
x = torch.randn(1, 1, 28, 28, device=device)
# Export to ONNX; verbose is a bool flag that prints the exported graph.
torch.onnx.export(
    model,
    x,
    'mnist.onnx',
    input_names=input_names,
    output_names=output_names,
    verbose=True,
)
